--- external/eg3d/training/loss.py	2023-04-06 03:55:57.421400822 +0000
+++ external_reference/eg3d/training/loss.py	2023-04-06 03:41:04.420707774 +0000
@@ -15,7 +15,7 @@
 from torch_utils import training_stats
 from torch_utils.ops import conv2d_gradfix
 from torch_utils.ops import upfirdn2d
-from training.dual_discriminator import filtered_resizing
+from external.eg3d.training.dual_discriminator import filtered_resizing
 
 #----------------------------------------------------------------------------
 
@@ -26,7 +26,17 @@
 #----------------------------------------------------------------------------
 
 class StyleGAN2Loss(Loss):
-    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0, gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased'):
+    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10,
+                 style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2,
+                 pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0,
+                 blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0,
+                 neural_rendering_resolution_initial=64,
+                 neural_rendering_resolution_final=None,
+                 neural_rendering_resolution_fade_kimg=0,
+                 gpc_reg_fade_kimg=1000, gpc_reg_prob=None,
+                 dual_discrimination=False, filter_mode='antialiased', 
+                 ignore_LR_disp=False, ignore_HR_disp=True,
+                 lambda_sky_pixel=0., lambda_ramp_end=0):
         super().__init__()
         self.device             = device
         self.G                  = G
@@ -52,6 +62,10 @@
         self.filter_mode = filter_mode
         self.resample_filter = upfirdn2d.setup_filter([1,3,3,1], device=device)
         self.blur_raw_target = True
+        self.ignore_LR_disp = ignore_LR_disp
+        self.ignore_HR_disp = ignore_HR_disp
+        self.lambda_sky_pixel = lambda_sky_pixel
+        self.lambda_ramp_end = lambda_ramp_end
         assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)
 
     def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False):
@@ -84,6 +98,21 @@
             img['image'] = augmented_pair[:, :img['image'].shape[1]]
             img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], size=img['image_raw'].shape[2:], mode='bilinear', antialias=True)
 
+        if self.ignore_LR_disp:
+            # Replace channel 3 (the LR disp channel) of image_raw with zeros so the discriminator never sees it.
+            dummy = torch.zeros_like(img['image_raw'])
+            img['image_raw'] = torch.cat([img['image_raw'][:, :3],
+                                          dummy[:, :1],
+                                          img['image_raw'][:, 4:]
+                                         ], dim=1)
+        if self.ignore_HR_disp:
+            # Replace channel 3 (the HR disp channel) of image with zeros so the discriminator never sees it.
+            dummy = torch.zeros_like(img['image'])
+            img['image'] = torch.cat([img['image'][:, :3],
+                                      dummy[:, :1],
+                                      img['image'][:, 4:]
+                                     ], dim=1)
+
         logits = self.D(img, c, update_emas=update_emas)
         return logits
 
@@ -124,6 +153,21 @@
                 training_stats.report('Loss/signs/fake', gen_logits.sign())
                 loss_Gmain = torch.nn.functional.softplus(-gen_logits)
                 training_stats.report('Loss/G/loss', loss_Gmain)
+                if self.lambda_sky_pixel > 0:
+                    if self.lambda_ramp_end > 0:
+                        ramp_multiplier = cur_nimg / self.lambda_ramp_end
+                        ramp_multiplier = np.clip(ramp_multiplier, 0, 1)
+                        training_stats.report('Loss/G/reg_ramp', ramp_multiplier)
+                    else:
+                        ramp_multiplier = 1
+                    weights_detach = gen_img['weights_raw'].detach()
+                    pixel_sum = gen_img['image_raw'][:, :3].sum(1, keepdim=True)
+                    # Penalize white pixels: exp(5*(sum-3)) over the first 3 channels, scaled by the detached weights (zero weight => no penalty).
+                    penalty = torch.exp(5*(pixel_sum-3)) * weights_detach
+                    sky_penalty = penalty.mean(dim=(2,3)) # Nx1
+                    loss_Gmain = (loss_Gmain + ramp_multiplier * self.lambda_sky_pixel * sky_penalty)
+                    training_stats.report('Loss/G/loss_sky_pixel', sky_penalty)
+                training_stats.report('Loss/G/loss_total', loss_Gmain)
             with torch.autograd.profiler.record_function('Gmain_backward'):
                 loss_Gmain.mean().mul(gain).backward()
 
